package vulnerability

import (
	"encoding/json"

	"github.com/aquasecurity/trivy/pkg/db"
	bolt "github.com/etcd-io/bbolt"
	"golang.org/x/xerrors"
)

// rootBucket is the name of the top-level bucket that every accessor in this
// file creates or reads from.
const rootBucket = "vulnerability"

// Put makes sure the root vulnerability bucket exists in tx and then hands
// that bucket, together with cveID, source and vuln, to db.Put.
//
// An error from creating the bucket is returned unchanged; otherwise the
// result of db.Put is returned.
func Put(tx *bolt.Tx, cveID, source string, vuln Vulnerability) error {
	bucketName := []byte(rootBucket)
	bkt, err := tx.CreateBucketIfNotExists(bucketName)
	if err != nil {
		return err
	}
	return db.Put(bkt, cveID, source, vuln)
}

// Update forwards cveID, source and vuln to db.Update, addressed at the root
// vulnerability bucket, and returns whatever db.Update returns.
func Update(cveID, source string, vuln Vulnerability) error {
	err := db.Update(rootBucket, cveID, source, vuln)
	return err
}

// BatchUpdate invokes fn with the root vulnerability bucket from inside the
// transaction that db.BatchUpdate passes to its callback.
//
// The bucket is created first if it does not exist; when that creation fails
// the callback returns the error and fn is never called. The return value is
// whatever db.BatchUpdate returns.
func BatchUpdate(fn func(b *bolt.Bucket) error) error {
	// withRoot adapts the bucket-level fn to the tx-level callback
	// signature expected by db.BatchUpdate.
	withRoot := func(tx *bolt.Tx) error {
		bkt, err := tx.CreateBucketIfNotExists([]byte(rootBucket))
		if err != nil {
			return err
		}
		return fn(bkt)
	}
	return db.BatchUpdate(withRoot)
}

// Get looks up cveID under the root vulnerability bucket via db.ForEach and
// decodes every returned value from JSON into a Vulnerability. The result is
// keyed by the same keys db.ForEach returned (named source here).
//
// It returns (nil, nil) when db.ForEach yields no values. A failure of the
// lookup itself, or of decoding any single value, is returned as a wrapped
// error and no partial map is returned.
func Get(cveID string) (map[string]Vulnerability, error) {
	values, err := db.ForEach(rootBucket, cveID)
	if err != nil {
		// The lookup targets the generic vulnerability bucket, so the message
		// names that rather than any one data source.
		return nil, xerrors.Errorf("error in vulnerability get: %w", err)
	}
	if len(values) == 0 {
		return nil, nil
	}

	// The final size is known up front, so allocate the map once.
	vulns := make(map[string]Vulnerability, len(values))
	for source, value := range values {
		var vuln Vulnerability
		if err = json.Unmarshal(value, &vuln); err != nil {
			return nil, xerrors.Errorf("failed to unmarshal Vulnerability JSON: %w", err)
		}
		vulns[source] = vuln
	}
	return vulns, nil
}
